package manager

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"gitee.com/xfrm/middleware/xsql/builder"
	"gitee.com/xfrm/middleware/xsql/scanner"
	"gitee.com/xfrm/middleware/xsql/xdb"
	"gitee.com/xfrm/middleware/xtime"
	"github.com/VividCortex/mysqlerr"
	"github.com/go-sql-driver/mysql"
	"github.com/opentracing/opentracing-go"
	"github.com/xwb1989/sqlparser"

	"gitee.com/xfrm/middleware/internal/breaker"
	"gitee.com/xfrm/middleware/xlog"
	"gitee.com/xfrm/middleware/xtrace"
)

// traceComponent is the value setDBSpanTags writes under the component tag.
const traceComponent = "xsql"

// bCheckTableName is a package-level flag (default true).
// NOTE(review): not referenced in this part of the file — confirm its use
// elsewhere in the package before relying on or removing it.
var bCheckTableName bool = true

// DB implements the XDB interface; a transaction handle can also be obtained
// from it and committed (GetTx is deprecated — prefer Begin, which returns a
// traced and metered *Tx).
type DB struct {
	// db is the underlying *sql.DB on which all statements are executed.
	db      *sql.DB
	// cluster labels breaker checks, metrics and span tags, and is used as the
	// fallback table name when none can be parsed from a query.
	cluster string
}

// ExecContext executes a statement that returns no rows (INSERT, UPDATE,
// DELETE, ...).
//
// The call is first checked against the circuit breaker for this cluster and
// the table parsed from query; if the breaker rejects it, an error is returned
// without touching the database. Otherwise the call runs inside a trace span,
// the trace ID (when ctx carries one) is prepended to the SQL as a comment,
// latency and error count are recorded under the "exec" label, and the outcome
// is reported back to the breaker.
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	const method = "xsql.DB.ExecContext"
	tbl := db.fetchTableName(query)
	if allowed := breaker.Entry(db.cluster, tbl); !allowed {
		xlog.Errorf(ctx, "%s trigger tidb breaker, because too many timeout sqls, cluster: %s, table: %s", method, db.cluster, tbl)
		return nil, errors.New("sql cause breaker, because too many timeout")
	}

	span, spanCtx := xtrace.StartSpanFromContext(ctx, method)
	defer span.Finish()
	stmt := injectSQLTraceIDLineComment(spanCtx, query)
	setDBSpanTags(span, db.cluster, tbl, fmt.Sprintf("%s %v", stmt, args))

	timer := xtime.NewTimeStat()
	result, execErr := db.db.ExecContext(spanCtx, stmt, args...)
	statMetricReqDur(spanCtx, db.cluster, tbl, "exec", timer.Millisecond())
	breaker.StatBreaker(db.cluster, tbl, execErr)
	statMetricReqErrTotal(spanCtx, db.cluster, tbl, "exec", execErr)
	return result, execErr
}

// QueryContext executes a query that returns rows, typically a SELECT.
//
// It is guarded by the circuit breaker for this cluster and the table parsed
// from query; when the breaker rejects the call, (nil, error) is returned
// without touching the database. Otherwise the call is traced, the trace ID
// (when present in ctx) is prepended to the SQL as a comment, latency and error
// count are recorded under the "query" label, and the outcome is reported to
// the breaker. The caller owns the returned *sql.Rows and must close it.
func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	const method = "xsql.DB.QueryContext"
	tbl := db.fetchTableName(query)
	if allowed := breaker.Entry(db.cluster, tbl); !allowed {
		xlog.Errorf(ctx, "%s trigger tidb breaker, because too many timeout sqls, cluster: %s, table: %s", method, db.cluster, tbl)
		return nil, errors.New("sql cause breaker, because too many timeout")
	}

	span, spanCtx := xtrace.StartSpanFromContext(ctx, method)
	defer span.Finish()
	stmt := injectSQLTraceIDLineComment(spanCtx, query)
	setDBSpanTags(span, db.cluster, tbl, fmt.Sprintf("%s %v", stmt, args))

	timer := xtime.NewTimeStat()
	rows, queryErr := db.db.QueryContext(spanCtx, stmt, args...)
	statMetricReqDur(spanCtx, db.cluster, tbl, "query", timer.Millisecond())
	breaker.StatBreaker(db.cluster, tbl, queryErr)
	statMetricReqErrTotal(spanCtx, db.cluster, tbl, "query", queryErr)
	return rows, queryErr
}

// QueryRowContext executes a query that is expected to return at most one row.
//
// The call is traced (with the trace ID, when present, prepended to the SQL as
// a comment) and its latency is recorded under the "query row" label.
//
// NOTE: when the circuit breaker rejects the call this returns a nil *sql.Row
// (an error cannot be attached to a *sql.Row from outside database/sql), so
// callers must nil-check the result before calling Scan. Unlike ExecContext
// and QueryContext, the query outcome is not reported to the breaker or to the
// error-count metric.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	const method = "xsql.DB.QueryRowContext"
	tbl := db.fetchTableName(query)
	if allowed := breaker.Entry(db.cluster, tbl); !allowed {
		xlog.Errorf(ctx, "%s trigger tidb breaker, because too many timeout sqls, cluster: %s, table: %s", method, db.cluster, tbl)
		return nil
	}

	span, spanCtx := xtrace.StartSpanFromContext(ctx, method)
	defer span.Finish()
	stmt := injectSQLTraceIDLineComment(spanCtx, query)
	setDBSpanTags(span, db.cluster, tbl, fmt.Sprintf("%s %v", stmt, args))

	timer := xtime.NewTimeStat()
	row := db.db.QueryRowContext(spanCtx, stmt, args...)
	statMetricReqDur(spanCtx, db.cluster, tbl, "query row", timer.Millisecond())
	return row
}
// SelectOne builds a SELECT for table from where and scans the result into
// item via scanner.Scan.
//
// Unless where already has a "_limit" key, a "_limit" of []uint{0, 1} is
// added (presumably offset 0, count 1 — interpreted by the builder). The
// caller's where map is copied first and never modified.
func (db *DB) SelectOne(ctx context.Context, table string, where map[string]interface{}, item interface{}) error {
	if db == nil {
		return errors.New("manager.XDB object couldn't be nil")
	}
	conds := copyWhere(where)
	if _, hasLimit := conds["_limit"]; !hasLimit {
		conds["_limit"] = []uint{0, 1}
	}

	stmt, args, err := builder.BuildSelectWithContext(ctx, table, conds, nil)
	if err != nil {
		return err
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, args)

	rows, err := db.QueryContext(ctx, stmt, args...)
	if err != nil || rows == nil {
		return err
	}
	defer rows.Close()
	return scanner.Scan(rows, item)
}

// Select builds a SELECT for table from the where conditions and scans the
// whole result set into items via scanner.Scan (items is presumably a pointer
// to a slice — TODO confirm against the scanner package). The rows are closed
// before returning.
func (db *DB) Select(ctx context.Context, table string, where map[string]interface{}, items interface{}) error {
	if db == nil {
		return errors.New("manager.XDB object couldn't be nil")
	}
	stmt, args, err := builder.BuildSelectWithContext(ctx, table, where, nil)
	if err != nil {
		return err
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, args)

	rows, err := db.QueryContext(ctx, stmt, args...)
	switch {
	case err != nil:
		return err
	case rows == nil:
		return nil
	}
	defer rows.Close()
	return scanner.Scan(rows, items)
}

// Insert builds an INSERT of the given rows into table, executes it and
// returns the driver's LastInsertId. It returns 0 together with any build or
// execution error.
func (db *DB) Insert(ctx context.Context, table string, data []map[string]interface{}) (int64, error) {
	if db == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}
	stmt, args, err := builder.BuildInsert(table, data)
	if err != nil {
		return 0, err
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, args)

	res, err := db.ExecContext(ctx, stmt, args...)
	if err == nil && res != nil {
		return res.LastInsertId()
	}
	return 0, err
}

// Update sets the columns in data on the rows of table matching where and
// returns the number of affected rows as reported by the driver.
func (db *DB) Update(ctx context.Context, table string, where, data map[string]interface{}) (int64, error) {
	if db == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}
	stmt, args, err := builder.BuildUpdate(table, where, data)
	if err != nil {
		return 0, err
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, args)

	res, err := db.ExecContext(ctx, stmt, args...)
	if err == nil {
		return res.RowsAffected()
	}
	return 0, err
}

// Upsert first tries to INSERT data merged with insertSet (insertSet wins on
// key collisions) and, on success, returns the LastInsertId.
//
// If the insert fails with a MySQL duplicate-key error (ER_DUP_ENTRY), it
// falls back to Update(table, where, data) and returns the affected-row count.
// Any other insert error is returned as-is.
func (db *DB) Upsert(ctx context.Context, table string, where, data, insertSet map[string]interface{}) (int64, error) {
	if db == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}
	row := make(map[string]interface{}, len(data)+len(insertSet))
	for col, val := range data {
		row[col] = val
	}
	// Applied second so insert-only values override data on the same column.
	for col, val := range insertSet {
		row[col] = val
	}

	id, err := db.Insert(ctx, table, []map[string]interface{}{row})
	if err == nil {
		return id, nil
	}
	if myErr, isMySQL := err.(*mysql.MySQLError); isMySQL && int(myErr.Number) == mysqlerr.ER_DUP_ENTRY {
		return db.Update(ctx, table, where, data)
	}
	return id, err
}

// Delete removes the rows of table matching where and returns the number of
// affected rows as reported by the driver.
func (db *DB) Delete(ctx context.Context, table string, where map[string]interface{}) (int64, error) {
	if db == nil {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}
	stmt, args, buildErr := builder.BuildDelete(table, where)
	if buildErr != nil {
		return 0, buildErr
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", stmt, args)

	res, execErr := db.ExecContext(ctx, stmt, args...)
	if execErr != nil {
		return 0, execErr
	}
	return res.RowsAffected()
}

// SelectCount runs SELECT COUNT(*) on table with the given where conditions
// and returns the counted value.
//
// The result rows are always closed, so the connection is returned to the
// pool, and iteration errors reported by rows.Err are surfaced to the caller.
func (db *DB) SelectCount(ctx context.Context, table string, where map[string]interface{}) (count int, err error) {
	if nil == db {
		return 0, errors.New("manager.XDB object couldn't be nil")
	}
	cond, vals, err := builder.BuildSelect(table, where, []string{builder.AggregateCount("*").Symble()})
	if nil != err {
		return 0, err
	}
	xlog.Debugf(ctx, "build cond: %s vals: %v", cond, vals)
	rows, err := db.QueryContext(ctx, cond, vals...)
	if nil != err || nil == rows {
		return 0, err
	}
	// Without Close the *sql.Rows keeps its connection checked out of the pool.
	defer rows.Close()
	for rows.Next() {
		if err = rows.Scan(&count); err != nil {
			return 0, err
		}
	}
	// Next also returns false on a read error; that error is only visible here.
	if err = rows.Err(); err != nil {
		return 0, err
	}
	return count, nil
}

// TxExecute runs exec inside a transaction obtained from db.Begin.
//
// If exec returns an error, the transaction is rolled back and exec's error is
// returned. Otherwise the transaction is committed and the commit error, if
// any, is returned. Rollback failures are logged, not returned.
//
// If exec does not return normally (it panics or calls runtime.Goexit), the
// transaction is rolled back before the panic continues to propagate, so the
// transaction is not left open.
func (db *DB) TxExecute(ctx context.Context, exec func(db xdb.XDB) error) error {
	if nil == db {
		return errors.New("manager.XDB object couldn't be nil")
	}
	tx, err := db.Begin(ctx)
	if err != nil {
		xlog.Errorf(ctx, "start transaction error :  err:%v", err)
		return err
	}

	// finished is still false in the deferred func only if exec never
	// returned; the normal paths below handle commit/rollback themselves.
	finished := false
	defer func() {
		if finished {
			return
		}
		if er := tx.Rollback(ctx); er != nil {
			xlog.Errorf(ctx, "rollback transaction error :  err:%v", er)
		}
	}()

	err = exec(tx)
	finished = true
	if err != nil {
		if er := tx.Rollback(ctx); er != nil {
			xlog.Errorf(ctx, "rollback transaction error :  err:%v", er)
		}
		return err
	}
	if err = tx.Commit(ctx); err != nil {
		xlog.Errorf(ctx, "commit transaction error :  err:%v", err)
		return err
	}
	return nil
}

// GetTx starts a transaction directly on the underlying *sql.DB and returns
// the raw *sql.Tx.
//
// Deprecated: calls made through the returned *sql.Tx are not traced or
// metered; use db.Begin() instead.
func (db *DB) GetTx() (*sql.Tx, error) {
	raw := db.db
	return raw.Begin()
}

// GetSQLDB returns the underlying *sql.DB.
//
// Deprecated: calls made directly on the returned *sql.DB are not traced or
// metered; for non-transactional work use db.QueryContext and the related
// methods instead.
func (db *DB) GetSQLDB() *sql.DB {
	pool := db.db
	return pool
}

// SetSQLDB replaces the underlying *sql.DB; intended for use with mocks.
func (db *DB) SetSQLDB(sqlDB *sql.DB) {
	db.db = sqlDB
}

// Begin starts a transaction and returns it wrapped in a *Tx bound to this
// DB's cluster. The call is traced, and its latency and error count are
// recorded under the "begin" label.
//
// ctx is passed to sql.DB.BeginTx, so a cancelled or expired ctx aborts the
// begin instead of silently opening a transaction. database/sql also rolls
// the transaction back if ctx is cancelled before Commit/Rollback.
//
// As before, the returned *Tx is non-nil even when err != nil (its inner
// *sql.Tx is then nil); check err before using it.
func (db *DB) Begin(ctx context.Context) (*Tx, error) {
	if nil == db {
		return nil, errors.New("manager.XDB object couldn't be nil")
	}
	var err error
	tx := &Tx{cluster: db.cluster}
	// trace
	span, ctx := xtrace.StartSpanFromContext(ctx, "xsql.DB.Begin")
	defer span.Finish()
	setDBSpanTags(span, tx.cluster, tx.cluster, "")

	st := xtime.NewTimeStat()
	// BeginTx honors ctx; plain Begin would use context.Background internally.
	tx.tx, err = db.db.BeginTx(ctx, nil)
	statMetricReqDur(ctx, tx.cluster, tx.cluster, "begin", st.Millisecond())
	statMetricReqErrTotal(ctx, tx.cluster, tx.cluster, "begin", err)
	return tx, err
}

// fetchTableName returns the first table name parsed from query. When none
// can be extracted it falls back to the cluster name, or to "" if db is nil.
// It is safe to call on a nil *DB.
func (db *DB) fetchTableName(query string) string {
	if name := extractSQLTableName(query); name != "" {
		return name
	}
	if db == nil {
		return ""
	}
	return db.cluster
}

// setDBSpanTags marks span as a client-side SQL call.
//
// Component, DB type and span kind are always set. Cluster, table (under both
// the palfish table key and the SQL table key) and stmt are set only when
// non-empty, in that order.
func setDBSpanTags(span opentracing.Span, cluster, table, stmt string) {
	span.SetTag(xtrace.TagComponent, traceComponent)
	span.SetTag(xtrace.TagDBType, xtrace.DBTypeSQL)
	span.SetTag(xtrace.TagSpanKind, xtrace.SpanKindClient)

	optional := []struct{ key, val string }{
		{xtrace.TagPalfishDBCluster, cluster},
		{xtrace.TagPalfishDBTable, table},
		{xtrace.TagDBSQLTable, table},
		{xtrace.TagDBStatement, stmt},
	}
	for _, tag := range optional {
		setTagIfNonEmpty(span, tag.key, tag.val)
	}
}

// setTagIfNonEmpty sets key=val on span, skipping empty values so that spans
// do not carry blank tags.
func setTagIfNonEmpty(span opentracing.Span, key, val string) {
	if val == "" {
		return
	}
	span.SetTag(key, val)
}

// extractSQLTableName parses query with sqlparser and returns the first
// non-empty table identifier met while walking the AST. It returns "" when
// the query does not parse or contains no table identifier.
func extractSQLTableName(query string) string {
	stmt, err := sqlparser.Parse(query)
	if err != nil {
		return ""
	}

	var found string
	// The visitor returns a non-nil error only to stop the walk once a name
	// has been found; the error value itself is discarded.
	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
		ident, isTable := node.(sqlparser.TableIdent)
		if !isTable {
			return true, nil
		}
		found = ident.String()
		if found == "" {
			return true, nil
		}
		return false, errors.New("has found")
	}, stmt)

	return found
}

// injectSQLTraceIDLineComment prefixes query with a "/*<traceID>*/ " comment
// carrying the trace ID found in ctx, so the executed statement text can be
// correlated with its trace.
//
// query is returned unchanged when ctx has no trace ID, or when the ID contains
// characters outside [0-9A-Za-z._:-]. That guard matters because the ID is
// spliced into SQL text: an ID containing "*/" would end the comment early, and
// one starting with "!" or "+" would turn it into a MySQL executable comment or
// an optimizer hint.
func injectSQLTraceIDLineComment(ctx context.Context, query string) string {
	traceID := xtrace.ExtractTraceID(ctx)
	if traceID == "" || !isSQLCommentSafe(traceID) {
		return query
	}

	return fmt.Sprintf("/*%s*/ %s", traceID, query)
}

// isSQLCommentSafe reports whether s consists only of ASCII letters, digits
// and '-', '_', '.', ':', so it cannot terminate or alter a /* */ comment.
func isSQLCommentSafe(s string) bool {
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		case c == '-' || c == '_' || c == '.' || c == ':':
		default:
			return false
		}
	}
	return true
}

// IsNotFound reports whether err is exactly scanner.ErrEmptyResult
// (presumably what scanner.Scan returns for an empty result set — confirm in
// the scanner package). Wrapped errors are not unwrapped.
func IsNotFound(err error) bool {
	switch err {
	case scanner.ErrEmptyResult:
		return true
	}
	return false
}

// copyWhere returns a shallow copy of src. The result is always a non-nil
// map, even for a nil src, so callers may add keys to it without affecting
// the caller-supplied conditions.
func copyWhere(src map[string]interface{}) map[string]interface{} {
	// Pre-size to avoid rehashing while copying.
	target := make(map[string]interface{}, len(src))
	for k, v := range src {
		target[k] = v
	}
	return target
}
